package org.west.sky.scripture.imports.core.websocket;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.util.Map;

/**
 * @author chenghz
 * @date 2023/2/10 13:36
 * @description: WebSocket握手拦截器
 */
public class WebSocketHandshakeInterceptor implements HandshakeInterceptor {

    private final static Logger LOGGER = LoggerFactory.getLogger(WebSocketHandshakeInterceptor.class);

    /** Handshake attribute key for the first path segment after {@code /importWs/} (the business name). */
    public static final String ATTR_BUSINESS = "WEBSOCKET_BUSINESS";

    /** Handshake attribute key for the second path segment after {@code /importWs/} (the subclass). */
    public static final String ATTR_SUBCLASS = "WEBSOCKET_SUBCLASS";

    /** Handshake attribute key for the third path segment after {@code /importWs/} (the key). */
    public static final String ATTR_KEY = "WEBSOCKET_KEY";

    /** URL path prefix this interceptor extracts parameters from. */
    private static final String PATH_PREFIX = "/importWs/";

    /** Number of path segments (business, subclass, key) required after {@link #PATH_PREFIX}. */
    private static final int REQUIRED_SEGMENTS = 3;

    /**
     * Extracts business / subclass / key from a request path of the form
     * {@code /importWs/{business}/{subclass}/{key}} and stores them in the handshake attributes
     * under {@link #ATTR_BUSINESS}, {@link #ATTR_SUBCLASS} and {@link #ATTR_KEY}.
     *
     * <p>Non-servlet requests and paths that do not start with {@code /importWs/} are let through
     * untouched (no attributes are added). Paths that start with the prefix but carry fewer than
     * three segments are rejected by returning {@code false}, instead of failing with an
     * {@code ArrayIndexOutOfBoundsException}.
     *
     * @param request             the current HTTP request
     * @param serverHttpResponse  the current HTTP response (not used)
     * @param webSocketHandler    the target WebSocket handler (not used)
     * @param attributes          handshake attributes that are passed on to the WebSocket session
     * @return {@code true} to proceed with the handshake, {@code false} to abort it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> attributes) throws Exception {
        if (!(request instanceof ServletServerHttpRequest)) {
            return true;
        }
        String path = request.getURI().getPath();
        if (!requestIsValid(path)) {
            return true;
        }
        String[] params = getParams(path);
        if (params.length < REQUIRED_SEGMENTS) {
            // Previously this indexed params[1]/params[2] unconditionally and threw AIOOBE.
            LOGGER.warn("Rejecting WebSocket handshake: expected {}<business>/<subclass>/<key> but got path '{}'",
                    PATH_PREFIX, path);
            return false;
        }
        attributes.put(ATTR_BUSINESS, params[0]);
        attributes.put(ATTR_SUBCLASS, params[1]);
        attributes.put(ATTR_KEY, params[2]);
        return true;
    }

    /**
     * No post-handshake processing is needed; intentionally a no-op.
     */
    @Override
    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {
        // intentionally empty
    }

    /**
     * Checks whether the path targets this interceptor: {@code /importWs/{business}/{subclass}/{key}}.
     * Any concrete authentication/authorization logic can be added here.
     *
     * @param url the request path (may be {@code null})
     * @return {@code true} if the path is non-blank and starts with {@code /importWs/}
     */
    private boolean requestIsValid(String url) {
        return StringUtils.hasText(url) && url.startsWith(PATH_PREFIX);
    }

    /**
     * Splits the part of the path after the leading {@code /importWs/} prefix into segments.
     * Only the leading prefix is stripped, so a later occurrence of {@code /importWs/} inside the
     * path is preserved. Trailing empty segments are dropped, as {@link String#split(String)} does.
     *
     * @param url the request path, expected to start with {@code /importWs/}
     * @return the path segments after the prefix; an empty array if the prefix is absent
     */
    private String[] getParams(String url) {
        if (url == null || !url.startsWith(PATH_PREFIX)) {
            return new String[0];
        }
        return url.substring(PATH_PREFIX.length()).split("/");
    }
}
